6.2 Episodic Generalization Optimization - EGO

Introduction

Human cognition is unique in its ability to perform a wide range of tasks and to learn new tasks quickly. Both abilities have long been associated with the acquisition of knowledge that can generalize across tasks and the flexible use of that knowledge to execute goal-directed behavior. In this tutorial, we introduce how this can emerge in a neural network by implementing the Episodic Generalization and Optimization (EGO) framework. The framework consists of an episodic memory module, which rapidly learns relationships between stimuli; a semantic pathway, which more slowly learns how stimuli map to responses; and a recurrent context module, which maintains a representation of task-relevant context information, integrates this over time, and uses it to recall context-relevant memories.

[Figure: the EGO framework]

The EGO framework consists of a control mechanism (context module; upper middle) and an episodic memory mechanism (bottom left). Episodic memory records conjunctions of stimuli (blue boxes), contexts (pink boxes), and observed responses (green boxes) at each time point (rows). Bidirectional arrows connect episodic memory to the stimulus, context, and output, indicating that these values can be stored in or used to query episodic memory, or retrieved from it when another field is queried. You can think of this as a more flexible dictionary that stores triplets instead of distinct key-value pairs, and allows any field (or any combinations of fields) to act as a key. The context module integrates previous context (recurrent connection) along with information about the stimulus and the context retrieved from memory.
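To make the dictionary analogy concrete, here is a minimal sketch in plain Python. It is not how PsyNeuLink implements episodic memory: the class `TripletMemory` and its methods `store` and `query` are hypothetical names chosen for illustration, and the sketch uses exact matching, whereas the model built later in this tutorial retrieves by similarity.

```python
class TripletMemory:
    """Toy episodic memory: stores triplets and lets any field act as a key."""

    def __init__(self):
        self.entries = []  # one dict per stored time point

    def store(self, stimulus, context, response):
        self.entries.append(
            {"stimulus": stimulus, "context": context, "response": response}
        )

    def query(self, **known):
        """Return the most recent entry that matches every given field, else None."""
        for entry in reversed(self.entries):
            if all(entry[field] == value for field, value in known.items()):
                return entry
        return None


memory = TripletMemory()
memory.store(stimulus="A", context="X", response="B")
memory.store(stimulus="A", context="Y", response="C")

# Use stimulus and context together as the key to retrieve the response ...
print(memory.query(stimulus="A", context="X")["response"])  # B
# ... or use the response alone as the key to retrieve the context.
print(memory.query(response="C")["context"])  # Y
```

Note that the same stimulus "A" is stored twice; only the context tells the two entries apart, which is the situation the task below is built around.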

Here we show that the EGO framework can emulate human behavior in a specific learning environment in which participants are trained on two sets of sequences that involve identical states presented in different orders, depending on the context. Empirical findings show that participants perform better when trained in blocks of each context than when training is interleaved.

Task: Coffee Shop World (CSW)

[Figures: The Suspicious Barista and Café Gratitude]

Imagine you are in a city with two coffee shops, each with a different layout and a different way of ordering. In one coffee shop, called The Suspicious Barista, you order first, pay for the coffee, and then sit down to wait until the waiter brings your order. In the other coffee shop, called Café Gratitude, you sit down first and wait until the waiter comes and takes your order; you pay after finishing the coffee.

This example demonstrates that many situations share similar stimuli but have different transition structures. Simple integration will help the system learn the transition structure, but because the situations are so similar it provides only a weak cue about the difference between them. In other words, the states (ordering, paying, and sitting down) are very similar in the two situations and are therefore hard to distinguish. This can be overcome by differentiating the context representations associated with each setting (e.g., learning different context representations for coffee shops with paranoid vs. gullible baristas). Recent empirical work suggests that people can learn to do this very effectively, but that this depends on the temporal structure of the environment: people do better when trained in blocks of each situation than when training is interleaved (Beukers et al., 2023).

We start with creating a dataset for the CSW task.

Installation and Setup

If the following cell fails to execute, please restart the kernel (or session) and run the cell again. This is a known issue when running in Google Colab.

%%capture
%pip install psyneulink

import psyneulink as pnl
import random

Generating data for the CSW task

We start by generating a dataset for the CSW task. The dataset consists of sequences of states. The task is to predict the next state given the current state and the context. The transition between states is determined by the context which in turn is determined by the “first” state in the sequence. The following figure illustrates the task structure:

[Figure: CSW task structure (left) and learning paradigms (right)]

On the left side of the figure, you can see the task structure:

The two colors represent different contexts: blue and orange.

  • If the first observed state in a sequence is 0, the participant is in the blue context.

    • The next state can be either 1 or 2.

    • From then on, transitions are deterministic:

      • 1 → 3 → 5 → 7

      • 2 → 4 → 6 → 8

  • If the first observed state is 9, the participant is in the orange context.

    • The sequence starts with either 1 or 2, but follows a different transition pattern:

      • 1 → 4 → 5 → 8

      • 2 → 3 → 6 → 7
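As a cross-check of the structure listed above, the transitions can be written down as a lookup table. This is only a sketch: the names `TRANSITIONS` and `rollout` are hypothetical and are not used elsewhere in this tutorial.

```python
# Transition tables keyed by the context cue (0 = blue, 9 = orange).
# Each inner dict maps the current state to the next state.
TRANSITIONS = {
    0: {1: 3, 3: 5, 5: 7, 2: 4, 4: 6, 6: 8},
    9: {1: 4, 4: 5, 5: 8, 2: 3, 3: 6, 6: 7},
}


def rollout(context, start_state):
    """Follow the context's transitions from start_state until no successor exists."""
    sequence = [context, start_state]
    while sequence[-1] in TRANSITIONS[context]:
        sequence.append(TRANSITIONS[context][sequence[-1]])
    return sequence


print(rollout(0, 1))  # [0, 1, 3, 5, 7]
print(rollout(9, 1))  # [9, 1, 4, 5, 8]

# States whose successor differs between the two contexts.
ambiguous_states = sorted(
    state for state in TRANSITIONS[0] if TRANSITIONS[0][state] != TRANSITIONS[9][state]
)
print(ambiguous_states)  # [1, 2, 3, 4, 5, 6]
```

Every non-terminal state has a different successor in the two contexts, so the current state alone is never enough to predict the next one.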

The right side of the figure shows the different learning paradigms:

In the blocked paradigm, participants are trained on blocks of the same context. In the interleaved paradigm, participants are trained on a mix of contexts. In the test paradigm, participants are tested on a sequence of random contexts.

We start with defining a function that generates a context-specific sequence:

def gen_context(
    context: int,
    start_state: int,
):
    """
    Generate a context-specific sequence.
    Args:
        context (int): The context to generate the sequence for. (0 or 9)
        start_state (int): The first state in the sequence. (1 or 2)
    """
    seq = [context, start_state]
    if context == 0:
        for _ in range(3):
            seq.append(seq[-1] + 2)
    elif context == 9:
        for _ in range(3):
            seq.append(seq[-1] + 1 if seq[-1] % 2 == 0 else seq[-1] + 3)
    return seq

"""Test the function"""
assert gen_context(0, 1) == [0, 1, 3, 5, 7]
assert gen_context(9, 2) == [9, 2, 3, 6, 7]

Generate a full dataset for the CSW task. Now, let’s create a function that returns the full trial sequence for a given paradigm and number of samples.

# Define the paradigms
BLOCKED = 'blocked'
INTERLEAVED = 'interleaved'


def gen_context_sequences(
        paradigm: str,
        train_contexts: int,
        test_contexts: int,
        block_size: int = 4,
):
    """
    Generate a dataset for the CSW task.
    Args:
        paradigm (str): The paradigm to generate the dataset for. (blocked or interleaved)
        train_contexts (int): The number of training contexts.
        test_contexts (int): The number of test contexts.
        block_size (int): The number of blocks in the blocked paradigm; each block holds train_contexts // block_size sequences.
    """
    assert train_contexts % block_size == 0, "The number of training samples must be a multiple of block_size."
    x = []
    if paradigm == INTERLEAVED:
        for idx in range(train_contexts):
            if idx % 2: # odd contexts -> context 0
                x += [gen_context(0, random.randint(1, 2))]
            else: # even contexts -> context 9
                x += [gen_context(9, random.randint(1, 2))]

    if paradigm == BLOCKED:
        for i in range(block_size): # block_size number of blocks
            if i % 2: # odd blocks -> context 0
                for _ in range(train_contexts // block_size):
                    x += [gen_context(0, random.randint(1, 2))]
            else: # even blocks -> context 9
                for _ in range(train_contexts // block_size):
                    x += [gen_context(9, random.randint(1, 2))]

    for _ in range(test_contexts):
        x += [gen_context(random.choice([0, 9]), random.randint(1, 2))]
    return x


context_sequences = gen_context_sequences(BLOCKED, 8, 4)
context_sequences
[[9, 2, 3, 6, 7],
 [9, 1, 4, 5, 8],
 [0, 1, 3, 5, 7],
 [0, 2, 4, 6, 8],
 [9, 2, 3, 6, 7],
 [9, 1, 4, 5, 8],
 [0, 1, 3, 5, 7],
 [0, 2, 4, 6, 8],
 [9, 1, 4, 5, 8],
 [9, 2, 3, 6, 7],
 [9, 1, 4, 5, 8],
 [9, 2, 3, 6, 7]]
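One way to see the difference between the two paradigms is to count how often the context cue changes from one sequence to the next. The sketch below is self-contained: the helper `count_context_switches` is a hypothetical name, the blocked example reuses the eight training sequences printed above, and the interleaved example is written by hand to follow the alternating 9, 0, 9, 0, ... pattern that the interleaved branch produces.

```python
def count_context_switches(sequences):
    """Count how often the context cue (first element) changes between sequences."""
    cues = [sequence[0] for sequence in sequences]
    return sum(previous != current for previous, current in zip(cues, cues[1:]))


# The eight training sequences from the blocked output above.
blocked_training = [
    [9, 2, 3, 6, 7], [9, 1, 4, 5, 8],
    [0, 1, 3, 5, 7], [0, 2, 4, 6, 8],
    [9, 2, 3, 6, 7], [9, 1, 4, 5, 8],
    [0, 1, 3, 5, 7], [0, 2, 4, 6, 8],
]

# A hand-written interleaved example (start states are arbitrary).
interleaved_training = [
    [9, 1, 4, 5, 8], [0, 2, 4, 6, 8],
    [9, 2, 3, 6, 7], [0, 1, 3, 5, 7],
    [9, 1, 4, 5, 8], [0, 1, 3, 5, 7],
    [9, 2, 3, 6, 7], [0, 2, 4, 6, 8],
]

print(count_context_switches(blocked_training))      # 3
print(count_context_switches(interleaved_training))  # 7
```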

The structure of the generated sequence is not “realistic” yet. The participant doesn’t see distinct contexts but rather a continuous stream of states, so we need to “flatten” the sequence. Also, instead of using integers to represent the states, we will use one-hot encoding:

def one_hot_encode(
        label: int,
        num_classes: int):
    """
    One hot encode a label (integer)
    Args:
        label (int): The label to encode (between 0 and num_classes-1)
        num_classes (int): The number of classes
    """
    return [1 if i == label else 0 for i in range(num_classes)]


def state_sequence(
        paradigm: str,
        train_trials: int,
        test_trials: int,
        context_length: int = 5,
        block_size: int = 4,
):
    """
    Generate a dataset for the CSW task.
    Args:
        paradigm (str): The paradigm to generate the dataset for. (blocked or interleaved)
        train_trials (int): The number of training trials.
        test_trials (int): The number of test trials.
        context_length (int): The length of the context.
        block_size (int): The number of blocks in the blocked paradigm.
    """


    assert train_trials % context_length == 0, "The number of training samples must be a multiple of context_length."
    assert test_trials % context_length == 0, "The number of test samples must be a multiple of context_length."

    train_contexts = train_trials // context_length
    test_contexts = test_trials // context_length

    train_context_sequences = gen_context_sequences(
        paradigm, train_contexts, test_contexts, block_size
    )

    states = []
    for context_sequence in train_context_sequences:
        for state_int in context_sequence:
            states.append(one_hot_encode(state_int, 11))
    return states


state_sequences = state_sequence(BLOCKED, 20, 5)
state_sequences
[[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
 [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
 [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
 [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
 [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
 [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]]
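One-hot output is hard to read at a glance. A small decoding helper makes it easier to inspect; `one_hot_decode` is a hypothetical helper, not part of the tutorial code, and the input below is copied from the first five rows of the output above.

```python
def one_hot_decode(vector):
    """Return the index of the single active unit in a one-hot vector."""
    assert sum(vector) == 1, "expected exactly one active unit"
    return vector.index(1)


first_sequence = [
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
    [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
]

decoded = [one_hot_decode(vector) for vector in first_sequence]
print(decoded)  # [9, 2, 3, 6, 7]
```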

🎯 Exercise 1

Why do we encode the states using one-hot encoding?

✅ Solution 1

One-hot encoding is used for categorical variables. The states have no inherent “order” and cannot be compared in a meaningful way using arithmetic operations. One-hot encoding captures this because the state vectors are “orthogonal” to each other.
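The point about orthogonality can be checked numerically. In this sketch the helpers `one_hot`, `dot`, and `squared_distance` are defined locally for illustration only.

```python
def one_hot(label, num_classes=11):
    """One-hot encode an integer label."""
    return [1 if i == label else 0 for i in range(num_classes)]


def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def squared_distance(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


# Integer labels wrongly suggest that state 1 is closer to state 2 than to state 8.
print(abs(1 - 2), abs(1 - 8))  # 1 7

# With one-hot vectors, distinct states are orthogonal and all equally far apart.
print(dot(one_hot(1), one_hot(2)), dot(one_hot(1), one_hot(8)))  # 0 0
print(squared_distance(one_hot(1), one_hot(2)),
      squared_distance(one_hot(1), one_hot(8)))  # 2 2
```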

🎯 Exercise 2

We want to train the EGO model in a supervised manner, but the generated dataset doesn’t allow us to do so yet. Why is this the case, and what do we need to do to be able to train the model?

💡 Hint 1

For supervised training, we need to provide a target for each input. Think about what the target should be in this case.

💡 Hint 2

The task in this case is to predict the next state given the current state.

✅ Solution 2

The target in this case is just the next state in the sequence:

x = state_sequence(BLOCKED, 20, 5)
y = x[1:] + [one_hot_encode(0, 11)]  # the last state has no successor, so its target is arbitrary (a new sequence would start with 0 or 9)
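To see what this shift does, it can help to look at the same idea with integer states instead of one-hot vectors. The following is a sketch only; `next_state_targets` is a hypothetical helper and the short stream is a hand-picked example.

```python
def next_state_targets(stream, pad):
    """Pair each state with its successor; the final state gets `pad` as its target."""
    return list(zip(stream, stream[1:] + [pad]))


# Two flattened sequences: [9, 2, 3, 6, 7] followed by [0, 2, 4, 6, 8].
stream = [9, 2, 3, 6, 7, 0, 2, 4, 6, 8]
pairs = next_state_targets(stream, pad=0)

print(pairs[0])   # (9, 2): the context cue predicts the start state
print(pairs[4])   # (7, 0): this target crosses a sequence boundary
print(pairs[-1])  # (8, 0): the padded, arbitrary target
```

Targets that cross a sequence boundary are the context cue of the following sequence, so they are not determined by the transition structure within a sequence.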

The EGO model

As mentioned earlier, the EGO model consists of three main components: an episodic memory module, a semantic pathway, and a recurrent context module. PsyNeuLink provides an EMComposition class that allows us to create the episodic memory module. The EMComposition class is a subclass of the Composition class. A strength of the PsyNeuLink framework is that it allows for the creation of complex compositions that can be used as mechanisms in other compositions. Here, we first look at the EMComposition class in isolation and then integrate it into the EGO model.

Episodic Memory Module - EMComposition

[Figure: the episodic memory (EM) module]

Here, we initialize the EMComposition for the episodic memory shown above. The EMComposition allows for specifying the structure of the episodic memory. Remember, the task here is to predict the state from the previous state and the context. Therefore, in our case each entry in the memory consists of a triplet of states:

  • The current state (green box)

  • The previous state (blue box)

  • The context (pink box)

Each state is represented as a vector with 11 elements (one hot encoding).

Here, we also specify the specific fields. Fields have three main parameters that have to be specified as a dictionary:

  • FIELD_WEIGHT: The weight of the field when retrieving from memory

  • LEARN_FIELD_WEIGHT: Whether the retrieval field weight should be learned (Here, we won’t learn these weights but set them)

  • TARGET_FIELD: Whether the field is a target field (meaning its “error” is calculated during learning)

🎯 Exercise 3

Before looking at the code below, think about what to set for the FIELD_WEIGHT and the TARGET_FIELD for the three different fields (current state, previous state, and context).

💡 Hint

The FIELD_WEIGHT specifies whether a field should be used during retrieval (and how strongly). It is a scalar value between 0 and 1. The TARGET_FIELD specifies whether a field is a target field.

✅ Solution

The FIELD_WEIGHT for the current state should be None since it is the target field and shouldn’t be used in retrieval. The FIELD_WEIGHT for both the previous state and the context should be set to an equal value (here we set them both to .5). The TARGET_FIELD should be set to True for the current state and False for the previous state and the context.

name = 'EM'  # a name for the EMComposition

# Memory parameters
state_size = 11  # the size of the state vector
memory_capacity = 1000  # here we set the maximum number of entries in the memory (we want to be able to store all 1000 trials)

# Fields

# State field
state_name = 'STATE'
state_retrieval_weight = None  # This entry is not used when retrieving from memory (remember, we want to predict the state)
state_is_target = True

# Previous state field
previous_state_name = 'PREVIOUS STATE'
previous_state_retrieval_weight = .5  # This entry is used when retrieving from memory
previous_state_is_target = False

# Context field
context_name = 'CONTEXT'
context_retrieval_weight = .5  # This entry is used when retrieving from memory
context_is_target = False

em = pnl.EMComposition(name=name,
                       memory_template=[[0] * state_size,  # state
                                        [0] * state_size,  # previous state
                                        [0] * state_size],  # context
                       memory_fill=.001,
                       memory_capacity=memory_capacity,
                       normalize_memories=False,
                       memory_decay_rate=0,  # no decay of memory
                       softmax_gain=10.,
                       softmax_threshold=.001,
                       fields={state_name: {pnl.FIELD_WEIGHT: state_retrieval_weight,
                                            pnl.LEARN_FIELD_WEIGHT: False,
                                            pnl.TARGET_FIELD: state_is_target},
                               previous_state_name: {pnl.FIELD_WEIGHT: previous_state_retrieval_weight,
                                                     pnl.LEARN_FIELD_WEIGHT: False,
                                                     pnl.TARGET_FIELD: previous_state_is_target},
                               context_name: {pnl.FIELD_WEIGHT: context_retrieval_weight,
                                              pnl.LEARN_FIELD_WEIGHT: False,
                                              pnl.TARGET_FIELD: context_is_target}},

                       normalize_field_weights=True,

                       concatenate_queries=False,
                       enable_learning=True,
                       learning_rate=.5,
                       device=pnl.CPU
                       )
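To build intuition for what the field weights and `softmax_gain` do, here is a conceptual sketch of similarity-weighted retrieval in plain Python. It is an assumption-laden illustration, not PsyNeuLink's actual computation: the functions `softmax`, `dot`, and `retrieve` are hypothetical, the vectors have only three elements, and we assume that matches are dot products combined by normalized field weights before a softmax is applied.

```python
import math


def softmax(values, gain=1.0):
    """Softmax with a gain (inverse temperature) parameter."""
    scaled = [gain * value for value in values]
    peak = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def retrieve(memory, query, field_weights, target_field, gain=10.0):
    """Return a softmax-weighted blend of `target_field` across all entries.

    Fields whose weight is None are not used for matching.
    """
    used = {name: w for name, w in field_weights.items() if w is not None}
    norm = sum(used.values())
    matches = [
        sum(w / norm * dot(entry[name], query[name]) for name, w in used.items())
        for entry in memory
    ]
    weights = softmax(matches, gain)
    size = len(memory[0][target_field])
    return [
        sum(w * entry[target_field][i] for w, entry in zip(weights, memory))
        for i in range(size)
    ]


# Two stored entries share the previous state but differ in context.
memory = [
    {"STATE": [0, 1, 0], "PREVIOUS STATE": [1, 0, 0], "CONTEXT": [1, 0, 0]},
    {"STATE": [0, 0, 1], "PREVIOUS STATE": [1, 0, 0], "CONTEXT": [0, 1, 0]},
]
field_weights = {"STATE": None, "PREVIOUS STATE": 0.5, "CONTEXT": 0.5}
query = {"PREVIOUS STATE": [1, 0, 0], "CONTEXT": [1, 0, 0]}

sharp = retrieve(memory, query, field_weights, target_field="STATE", gain=10.0)
soft = retrieve(memory, query, field_weights, target_field="STATE", gain=1.0)
print(sharp)
print(soft)
```

With a high gain the retrieved state is dominated by the entry whose context matches the query; with a low gain the two entries are blended more evenly.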

Let’s see what the EMComposition looks like:

em.show_graph(output_fmt='jupyter')
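`show_graph` relies on the Graphviz `dot` executable being on the system PATH; if it is missing, the call fails with an `ExecutableNotFound` error. The following sketch checks for the executable using only the Python standard library. The helper name `graphviz_available` is hypothetical, and how Graphviz is installed depends on your platform.

```python
import shutil


def graphviz_available():
    """Return True if the Graphviz `dot` executable can be found on PATH."""
    return shutil.which("dot") is not None


if graphviz_available():
    print("Graphviz found: show_graph() should be able to render.")
else:
    print("Graphviz `dot` not found on PATH: install Graphviz before calling show_graph().")
```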

Input, Context, and Output Layers

Next, we “hook up” the EMComposition to the input, output, and context layers.

[Figure: the EGO model]

We start by defining the layers.

🎯 Exercise 4

Before defining the layers, make sure you understand the input and output of the model:

  • Although the episodic memory composition has three memory “slots”, our training set consists only of a stream of single states. How can we use this single state

state_input_layer = pnl.ProcessingMechanism(name=state_name, input_shapes=state_size)

previous_state_layer = pnl.ProcessingMechanism(name=previous_state_name, input_shapes=state_size)

context_layer = pnl.TransferMechanism(name=context_name,
                                  input_shapes=state_size,
                                  function=pnl.Tanh,
                                  integrator_mode=True,
                                  integration_rate=.69)

# The output layer:
prediction_layer = pnl.ProcessingMechanism(name='PREDICTION', input_shapes=state_size)
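To build intuition for `integrator_mode` and `integration_rate`, the sketch below assumes that the context layer forms an exponentially weighted average of its input, (1 - rate) * previous + rate * input, and then applies tanh. That is our reading of the integrator mode rather than a verified reproduction of PsyNeuLink's computation; the function `integrate_context` is hypothetical and uses three-element vectors for brevity.

```python
import math


def integrate_context(inputs, rate=0.69):
    """Leaky integration followed by tanh, applied element-wise.

    Assumed update: integrated = (1 - rate) * integrated + rate * input
    """
    integrated = [0.0] * len(inputs[0])
    history = []
    for vector in inputs:
        integrated = [
            (1 - rate) * previous + rate * x
            for previous, x in zip(integrated, vector)
        ]
        history.append([math.tanh(value) for value in integrated])
    return history


# Three different states presented one after another.
stream = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
context = integrate_context(stream)
for step in context:
    print([round(value, 3) for value in step])
```

After the third input, the context still carries a fading trace of the first two states, with the most recent state represented most strongly. This is the sense in which the context layer integrates information over time.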

After defining the layers, we need to specify the pathways between the layers. Before looking at the code below, think about which pathways (if any) are learned and which ones are fixed.

# Names for the input nodes of the EMComposition have the form: <node_name> + ' [QUERY]' or <node_name> + ' [VALUE]' or <node_name> + ' [RETRIEVED]' (see above)

QUERY = ' [QUERY]'
VALUE = ' [VALUE]'
RETRIEVED = ' [RETRIEVED]'

# Pathways
state_to_previous_state_pathway = [state_input_layer,
                                   pnl.MappingProjection(matrix=pnl.IDENTITY_MATRIX,
                                                         learnable=False),
                                   previous_state_layer]
state_to_context_pathway = [state_input_layer,
                            pnl.MappingProjection(matrix=pnl.IDENTITY_MATRIX,
                                                  learnable=False),
                            context_layer]
state_to_em_pathway = [state_input_layer,
                       pnl.MappingProjection(sender=state_input_layer,
                                             receiver=em.nodes[state_name + VALUE],
                                             matrix=pnl.IDENTITY_MATRIX,
                                             learnable=False),
                       em]
previous_state_to_em_pathway = [previous_state_layer,
                                pnl.MappingProjection(sender=previous_state_layer,
                                                      receiver=em.nodes[previous_state_name + QUERY],
                                                      matrix=pnl.IDENTITY_MATRIX,
                                                      learnable=False),
                                em]
context_learning_pathway = [context_layer,
                            pnl.MappingProjection(sender=context_layer,
                                                  matrix=pnl.IDENTITY_MATRIX,
                                                  receiver=em.nodes[context_name + QUERY],
                                                  learnable=True),
                            em,
                            pnl.MappingProjection(sender=em.nodes[state_name + RETRIEVED],
                                                  receiver=prediction_layer,
                                                  matrix=pnl.IDENTITY_MATRIX,
                                                  learnable=False),
                            prediction_layer]

Now, we can create the composition

learning_rate = .5
loss_spec = pnl.Loss.BINARY_CROSS_ENTROPY
model_name = 'EGO'
device = pnl.CPU

ego_model = pnl.AutodiffComposition([state_to_previous_state_pathway,
                                    state_to_context_pathway,
                                    state_to_em_pathway,
                                    previous_state_to_em_pathway,
                                    context_learning_pathway],
                                   learning_rate=learning_rate,
                                   loss_spec=loss_spec,
                                   name=model_name,
                                   device=device)


ego_model.show_graph(output_fmt='jupyter')
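The `loss_spec` above selects binary cross-entropy, which compares the predicted state vector with the one-hot target element by element. Here is a minimal plain-Python sketch of that loss, assuming a mean over elements; the function `binary_cross_entropy` is written for illustration, and the exact reduction and clipping used internally may differ.

```python
import math


def binary_cross_entropy(prediction, target, eps=1e-7):
    """Mean binary cross-entropy between a prediction vector and a 0/1 target."""
    total = 0.0
    for p, t in zip(prediction, target):
        p = min(max(p, eps), 1 - eps)  # clip to avoid log(0)
        total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
    return total / len(target)


target = [0, 1, 0]
good_prediction = [0.01, 0.98, 0.01]  # most of the mass on the correct state
bad_prediction = [0.98, 0.01, 0.01]   # most of the mass on a wrong state

good_loss = binary_cross_entropy(good_prediction, target)
bad_loss = binary_cross_entropy(bad_prediction, target)
print(good_loss < bad_loss)  # True
```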
    177     wanted = ', '.join(f'{name}={value!r}'
    178                        for name, value in deprecated.items())
    179     warnings.warn(f'The signature of {func_name} will be reduced'
    180                   f' to {supported_number} positional arg{s_}{qualification}'
    181                   f' {list(supported)}: pass {wanted} as keyword arg{s_}',
    182                   stacklevel=stacklevel,
    183                   category=category)
--> 185 return func(*args, **kwargs)

File /opt/hostedtoolcache/Python/3.11.14/x64/lib/python3.11/site-packages/graphviz/piping.py:121, in Pipe._pipe_legacy(self, format, renderer, formatter, neato_no_op, quiet, engine, encoding)
    112 @_tools.deprecate_positional_args(supported_number=1, ignore_arg='self')
    113 def _pipe_legacy(self,
    114                  format: typing.Optional[str] = None,
   (...)    119                  engine: typing.Optional[str] = None,
    120                  encoding: typing.Optional[str] = None) -> typing.Union[bytes, str]:
--> 121     return self._pipe_future(format,
    122                              renderer=renderer,
    123                              formatter=formatter,
    124                              neato_no_op=neato_no_op,
    125                              quiet=quiet,
    126                              engine=engine,
    127                              encoding=encoding)

File /opt/hostedtoolcache/Python/3.11.14/x64/lib/python3.11/site-packages/graphviz/piping.py:149, in Pipe._pipe_future(self, format, renderer, formatter, neato_no_op, quiet, engine, encoding)
    146 if encoding is not None:
    147     if codecs.lookup(encoding) is codecs.lookup(self.encoding):
    148         # common case: both stdin and stdout need the same encoding
--> 149         return self._pipe_lines_string(*args, encoding=encoding, **kwargs)
    150     try:
    151         raw = self._pipe_lines(*args, input_encoding=self.encoding, **kwargs)

File /opt/hostedtoolcache/Python/3.11.14/x64/lib/python3.11/site-packages/graphviz/backend/piping.py:212, in pipe_lines_string(engine, format, input_lines, encoding, renderer, formatter, neato_no_op, quiet)
    206 cmd = dot_command.command(engine, format,
    207                           renderer=renderer,
    208                           formatter=formatter,
    209                           neato_no_op=neato_no_op)
    210 kwargs = {'input_lines': input_lines, 'encoding': encoding}
--> 212 proc = execute.run_check(cmd, capture_output=True, quiet=quiet, **kwargs)
    213 return proc.stdout

File /opt/hostedtoolcache/Python/3.11.14/x64/lib/python3.11/site-packages/graphviz/backend/execute.py:81, in run_check(cmd, input_lines, encoding, quiet, **kwargs)
     79 except OSError as e:
     80     if e.errno == errno.ENOENT:
---> 81         raise ExecutableNotFound(cmd) from e
     82     raise
     84 if not quiet and proc.stderr:

ExecutableNotFound: failed to execute PosixPath('dot'), make sure the Graphviz executables are on your systems' PATH
<graphviz.graphs.Digraph at 0x7f9985561310>

We also need to specify the learning pathway, which can be inferred from the parameters we have already set (the target in the EMComposition, and the context-to-EM pathway marked as learnable):

learning_components = ego_model.infer_backpropagation_learning_pathways(pnl.ExecutionMode.PyTorch)

ego_model.add_projection(pnl.MappingProjection(sender=state_input_layer,
                                              receiver=learning_components[0],
                                              learnable=False))
(MappingProjection MappingProjection from STATE[OutputPort-0] to TARGET for PREDICTION[InputPort-0])

We also have to make sure that the episodic memory (em) is executed before the previous state layer and the context layer:

ego_model.scheduler.add_condition(em, pnl.BeforeNodes(previous_state_layer, context_layer))
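One plausible reading of this ordering constraint is that the memory should be queried with the values the two layers held on the previous trial, before they are overwritten with the current input. The toy sketch below illustrates that read-before-update idea in plain Python. It does not use PsyNeuLink, and `step` and the dictionary standing in for a layer are hypothetical names introduced only for illustration.

```python
def step(layer, new_state, em_first):
    """Return the value the memory would see as its query on this trial.

    `layer` is a plain dict standing in for a stateful layer (hypothetical,
    not a PsyNeuLink object). `em_first` controls whether the memory reads
    the layer before or after the layer is updated with `new_state`.
    """
    if em_first:
        query = layer["value"]      # memory reads the old (previous-trial) value
        layer["value"] = new_state  # the layer is updated afterwards
    else:
        layer["value"] = new_state  # the layer is updated first
        query = layer["value"]      # memory now sees the current state instead
    return query


# The layer holds state "A" from the previous trial; the current input is "B".
query_em_first = step({"value": "A"}, "B", em_first=True)
query_em_last = step({"value": "A"}, "B", em_first=False)
```

With `em_first=True` the query is the previous state, which is what a next-state prediction needs. With `em_first=False` the query is already the current state, so the model would be asked to predict something it has just been given.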

Now, we are set to run the model:

trials = state_sequence(BLOCKED, 800, 200)

ego_model.learn(inputs={state_name: trials},
                learning_rate=.5,
                execution_mode=pnl.ExecutionMode.PyTorch)

Once training has finished, we can plot the prediction error for each stimulus:

import matplotlib.pyplot as plt
import numpy as np

TOTAL_NUM_STIMS = len(trials)
TARGETS = np.array(trials[1:] + [one_hot_encode(0, 11)])
curriculum_type = BLOCKED

fig, axes = plt.subplots(1, 1, figsize=(12, 5))
# Per-trial L1 loss: sum of absolute differences between prediction and target
axes.plot((np.abs(ego_model.results[1:TOTAL_NUM_STIMS, 2] - TARGETS[:TOTAL_NUM_STIMS - 1])).sum(-1))
axes.set_xlabel('Stimuli')
axes.set_ylabel('Loss')

plt.suptitle(f"{curriculum_type} Training")
plt.show()
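The per-trial loss curve can be noisy, so it may help to smooth it before comparing curricula. The sketch below recomputes the same L1 loss on toy one-hot data and applies a simple moving average. `l1_loss_per_trial` and `moving_average` are hypothetical helpers, not part of the tutorial or of PsyNeuLink, and the toy arrays stand in for the model results and `TARGETS`.

```python
import numpy as np


def l1_loss_per_trial(predictions, targets):
    """Sum of absolute differences between prediction and target, per trial."""
    predictions = np.asarray(predictions, dtype=float)
    targets = np.asarray(targets, dtype=float)
    return np.abs(predictions - targets).sum(axis=-1)


def moving_average(values, window):
    """Unweighted moving average; returns len(values) - window + 1 points."""
    values = np.asarray(values, dtype=float)
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode="valid")


# Toy data: four trials over three one-hot encoded states (illustrative only).
toy_targets = np.eye(3)[[0, 1, 2, 0]]
toy_predictions = np.array([[1.0, 0.0, 0.0],
                            [0.0, 0.0, 1.0],
                            [0.0, 0.0, 1.0],
                            [0.5, 0.5, 0.0]])

toy_loss = l1_loss_per_trial(toy_predictions, toy_targets)
toy_smoothed = moving_average(toy_loss, window=2)
```

A completely wrong one-hot prediction gives a loss of 2, and a perfect one gives 0. For the real model, the array passed to `axes.plot` above could be passed through `moving_average` with a larger window before plotting.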

🎯 Exercise 4

Run the model for the interleaved paradigm. What do you expect? Compare the two results and explain the differences.